#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
// Prime modulus required by the problem's output format.
const long long MOD = 998244353;

// Reads two integers n and m from "bpmp.in" and writes
// (n - 1 + (m - 1) * n) mod MOD, i.e. (n * m - 1) mod MOD, to "bpmp.out".
//
// Each factor is reduced modulo MOD *before* multiplying. The original
// computed (m - 1) * n directly in long long. That is signed overflow
// (undefined behavior) as soon as n * m exceeds LLONG_MAX, which the
// modular output suggests the inputs may do. After reduction both factors
// are < MOD (< 2^30), so their product fits comfortably in long long.
int main()
{
	freopen("bpmp.in","r",stdin);
	freopen("bpmp.out","w",stdout);

	long long n, m;
	if (scanf("%lld%lld", &n, &m) != 2)
		return 0;  // malformed/empty input: nothing sensible to print

	// '%' keeps the dividend's sign in C++, so fold into [0, MOD) explicitly.
	long long rn = ((n % MOD) + MOD) % MOD;
	long long rm = ((m % MOD) + MOD) % MOD;
	long long product = rn * rm % MOD;          // (n * m) mod MOD
	long long answer = (product - 1 + MOD) % MOD; // (n * m - 1) mod MOD, never negative

	printf("%lld", answer);
	return 0;
}
